{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'1.0.0'"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import  torch\n",
    "import torchvision\n",
    "torch.__version__"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 4.5 多GPU并行训练\n",
    "\n",
    "在我们进行神经网络训练的时候，因为计算量巨大所以单个GPU运算会使得计算时间很长，使得我们不能够及时的得到结果，例如我们如果使用但GPU使用ImageNet的数据训练一个分类器，可能会花费一周甚至一个月的时间。所以在Pytorch中引入了多GPU计算的机制，这样使得训练速度可以指数级的增长。\n",
    "\n",
    "stanford大学的[DAWNBench](https://dawn.cs.stanford.edu/benchmark/) 就记录了目前为止的一些使用多GPU计算的记录和实现代码，有兴趣的可以看看。\n",
    "\n",
    "这章里面我们要介绍的三个方式来使用多GPU加速"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.5.1 torch.nn.DataParalle\n",
    "一般情况下我们都会使用一台主机带多个显卡，这样是一个最节省预算的方案，在Pytorch中为我们提供了一个非常简单的方法来支持但主机多GPU，那就`torch.nn.DataParalle` 我们只要将我们自己的模型作为参数，直接传入即可，剩下的事情Pytorch都为我们做了"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "#使用内置的一个模型，我们这里以resnet50为例\n",
    "model = torchvision.models.resnet50()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "DataParallel(\n",
       "  (module): ResNet(\n",
       "    (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "    (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "    (relu): ReLU(inplace)\n",
       "    (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "    (layer1): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "          (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "    )\n",
       "    (layer2): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "      (3): Bottleneck(\n",
       "        (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "    )\n",
       "    (layer3): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "      (3): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "      (4): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "      (5): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "    )\n",
       "    (layer4): Sequential(\n",
       "      (0): Bottleneck(\n",
       "        (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "        (downsample): Sequential(\n",
       "          (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "          (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        )\n",
       "      )\n",
       "      (1): Bottleneck(\n",
       "        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "      (2): Bottleneck(\n",
       "        (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "        (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "        (relu): ReLU(inplace)\n",
       "      )\n",
       "    )\n",
       "    (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)\n",
       "    (fc): Linear(in_features=2048, out_features=1000, bias=True)\n",
       "  )\n",
       ")"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#模型使用多GPU\n",
    "mdp = torch.nn.DataParallel(model)\n",
    "mdp"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "只要这样一个简单的包裹，Pytorch已经为我们做了很多复杂的工作。我们只需要增大我们训练的batch_size(一般计算为N倍，N为显卡数量)，其他代码不需要任何改动。\n",
    "虽然代码不需要做更改，但是batch size太大了训练收敛会很慢，所以还要把学习率调大一点。大学率也会使得模型的训练在早期的阶段变得十分不稳定，所以这里需要一个学习率的热身（warm up） 来稳定梯度的下降，然后在逐步的提高学习率。\n",
    "\n",
    "这种热身只有在超级大的批次下才需要进行，一般我们这种一机4卡或者说在batch size 小于 5000（个人测试）基本上是不需要的。例如最近富士通使用2048个GPU,74秒训练完成resnet50的实验中使用的batch size 为 81920 [arivx](http://www.arxiv.org/abs/1903.12650)这种超大的size才需要。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DataParallel的并行处理机制是，首先将模型加载到主 GPU 上(默认的第一个GPU，GPU0为主GPU)，然后再将模型复制到各个指定的从 GPU 中，然后将输入数据按 batch 维度进行划分，具体来说就是每个 GPU 分配到的数据 batch 数量是总输入数据的 batch 除以指定 GPU 个数。每个 GPU 将针对各自的输入数据独立进行 forward 计算，最后将各个 GPU 的 loss 进行求和，再用反向传播更新单个 GPU 上的模型参数，再将更新后的模型参数复制到剩余指定的 GPU 中，这样就完成了一次迭代计算。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DataParallel其实也是一个nn.Model所以我们可以保存权重的方法和一般的nn.Model没有区别，只不过如果你想使用单卡或者cpu作为推理的时候需要从里面读出原始的model "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "ResNet(\n",
       "  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)\n",
       "  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "  (relu): ReLU(inplace)\n",
       "  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)\n",
       "  (layer1): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "        (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "  )\n",
       "  (layer2): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "    (3): Bottleneck(\n",
       "      (conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "  )\n",
       "  (layer3): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "    (3): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "    (4): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "    (5): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "  )\n",
       "  (layer4): Sequential(\n",
       "    (0): Bottleneck(\n",
       "      (conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "      (downsample): Sequential(\n",
       "        (0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)\n",
       "        (1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      )\n",
       "    )\n",
       "    (1): Bottleneck(\n",
       "      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "    (2): Bottleneck(\n",
       "      (conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)\n",
       "      (bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)\n",
       "      (bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)\n",
       "      (relu): ReLU(inplace)\n",
       "    )\n",
       "  )\n",
       "  (avgpool): AvgPool2d(kernel_size=7, stride=1, padding=0)\n",
       "  (fc): Linear(in_features=2048, out_features=1000, bias=True)\n",
       ")"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "#获取到原始的model\n",
    "m=mdp.module\n",
    "m"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "DataParallel会将定义的网络模型参数默认放在GPU 0上，所以dataparallel实质是可以看做把训练参数从GPU拷贝到其他的GPU同时训练，这样会导致内存和GPU使用率出现很严重的负载不均衡现象，即GPU 0的使用内存和使用率会大大超出其他显卡的使用内存，因为在这里GPU0作为master来进行梯度的汇总和模型的更新，再将计算任务下发给其他GPU，所以他的内存和使用率会比其他的高。\n",
    "\n",
    "所以我们使用新的torch.distributed来构建更为同步的分布式运算。使用torch.distributed不仅可以支持单机还可以支持多个主机，多个GPU进行计算。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.5.2 torch.distributed\n",
    "`torch.distributed`相对于`torch.nn.DataParalle` 是一个底层的API，所以我们要修改我们的代码，使其能够独立的在机器（节点）中运行。我们想要完全实现分布式，并且在每个结点的每个GPU上独立运行进程，这一共需要N个进程。N是我们的GPU总数，这里我们以4来计算。\n",
    "\n",
    "首先 初始化分布式后端，封装模型以及准备数据，这些数据用于在独立的数据子集中训练进程。修改后的代码如下"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# 以下脚本在jupyter notebook执行肯定会不成功，请保存成py文件后测试\n",
    "from torch.utils.data.distributed import DistributedSampler\n",
    "from torch.utils.data import DataLoader\n",
    "\n",
    "# 这里的node_rank是本地GPU的标识\n",
    "parser = argparse.ArgumentParser()\n",
    "parser.add_argument(\"--node_rank\", type=int)\n",
    "args = parser.parse_args()\n",
    "\n",
    "# 使用Nvdea的nccl来初始化节点 \n",
    "torch.distributed.init_process_group(backend='nccl')\n",
    "\n",
    "# 封装分配给当前进程的GPU上的模型\n",
    "device = torch.device('cuda', arg.local_rank)\n",
    "model = model.to(device)\n",
    "distrib_model = torch.nn.parallel.DistributedDataParallel(model,\n",
    "                                                          device_ids=[args.node_rank],\n",
    "                                                          output_device=args.node_rank)\n",
    "\n",
    "# 将数据加载限制为数据集的子集（不包括当前进程）\n",
    "sampler = DistributedSampler(dataset)\n",
    "\n",
    "dataloader = DataLoader(dataset, sampler=sampler)\n",
    "for inputs, labels in dataloader:\n",
    "    predictions = distrib_model(inputs.to(device))         # 正向传播\n",
    "    loss = loss_function(predictions, labels.to(device))   # 计算损失\n",
    "    loss.backward()                                        # 反向传播\n",
    "    optimizer.step()                                       # 优化\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "在运行时我们也不能简单的使用`python 文件名`来执行了，我们这里需要使用PyTorch中为我们准备好的torch.distributed.launch运行脚本。它能自动进行环境变量的设置，并使用正确的node_rank参数调用脚本。\n",
    "\n",
    "这里我们要准备一台机器作为master，所有的机器都要求能对它进行访问。因此，它需要拥有一个可以访问的IP地址（示例中为：196.168.100.100）以及一个开放的端口（示例中为：6666）。我们将使用torch.distributed.launch在第一台机器上运行脚本：\n",
    "```bash\n",
    "python -m torch.distributed.launch --nproc_per_node=2 --nnodes=2 --node_rank=0 --master_addr=\"192.168.100.100\" --master_port=6666 文件名 (--arg1 --arg2 等其他参数)\n",
    "```\n",
    "第二台主机上只需要更改 `--node_rank=0`即可"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "很有可能你在运行的时候报错，那是因为我们没有设置NCCL socket网络接口\n",
    "我们以网卡名为ens3为例，输入\n",
    "```bash\n",
    "export NCCL_SOCKET_IFNAME=ens3\n",
    "```\n",
    "ens3这个名称 可以使用ifconfig命令查看确认 "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "参数说明：\n",
    "\n",
    "--nproc_per_node ： 主机中包含的GPU总数\n",
    "\n",
    "--nnodes ： 总计的主机数\n",
    "\n",
    "--node_rank ：主机中的GPU标识\n",
    "\n",
    "其他一些参数可以查看[官方的文档](https://github.com/pytorch/pytorch/blob/master/torch/distributed/launch.py)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "torch.distributed 不仅支持nccl还支持其他的两个后端 gloo和mpi，具体的对比这里就不细说了，请查看[官方的文档](https://pytorch.org/docs/stable/distributed.html)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 4.5.3 torch.utils.checkpoint\n",
    "在我们训练时，可能会遇到（目前我还没遇到）训练集的单个样本比内存还要大根本载入不了，那我我们如何来训练呢？\n",
    "\n",
    "pytorch为我们提供了梯度检查点（gradient-checkpointing）节省计算资源，梯度检查点会将我们连续计算的元正向和元反向传播切分成片段。但由于需要增加额外的计算以减少内存需求，该方法效率会有一些下降，但是它在某些示例中有较为明显的优势，比如在长序列上训练RNN模型，这个由于复现难度较大 就不介绍了，官方文档在[这里](https://pytorch.org/docs/stable/checkpoint.html) 遇到这种情况的朋友可以查看下官方的解决方案。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.5.6"
  },
  "latex_envs": {
   "LaTeX_envs_menu_present": true,
   "autoclose": false,
   "autocomplete": true,
   "bibliofile": "biblio.bib",
   "cite_by": "apalike",
   "current_citInitial": 1,
   "eqLabelWithNumbers": true,
   "eqNumInitial": 1,
   "hotkeys": {
    "equation": "Ctrl-E",
    "itemize": "Ctrl-I"
   },
   "labels_anchors": false,
   "latex_user_defs": false,
   "report_style_numbering": false,
   "user_envs_cfg": false
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": false,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": true
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
